# hcquant

import pandas as pd
from dateutil.parser import parse
import tushare as ts
from fastcache import lru_cache
from abc import ABCMeta, abstractmethod
from ..utils.engine import calendar_engine
import datetime as dt
# Current local date as a 'YYYYMMDD' string. Evaluated once, when this module
# is imported, so it does not advance while the process keeps running.
today = dt.datetime.now().strftime('%Y%m%d')

# Utility helpers
def dt_handle(date):
    """Normalize a date-like value to a compact 'YYYYMMDD' string.

    The value is converted with ``pd.to_datetime`` and formatted with
    ``'%Y%m%d'``. If that raises ``ValueError`` the input is handed back
    unchanged.

    Args:
        date: anything ``pd.to_datetime`` accepts.

    Returns:
        The formatted date, or ``date`` itself when a ``ValueError`` occurs.
    """
    try:
        formatted = pd.to_datetime(date).strftime('%Y%m%d')
    except ValueError:
        # Best-effort helper: values that cannot be converted pass through.
        return date
    return formatted

# Trading-calendar loader
def get_tradeday():
    """Load the SSE trading calendar up to the module-level ``today``.

    Queries ``AShareCalendar`` through ``calendar_engine`` for rows whose
    ``S_INFO_EXCHMARKET`` is 'SSE' and whose ``TRADE_DAYS`` is not later than
    ``today``, ordered by ``TRADE_DAYS``.

    Returns:
        pandas.DataFrame: the query result with a single ``calendarDate``
        column (the alias given to ``TRADE_DAYS``).
    """
    # ``today`` is produced by this module, not by callers, so formatting it
    # into the statement does not expose the query to external input.
    query = """
        select TRADE_DAYS as calendarDate from AShareCalendar where S_INFO_EXCHMARKET='SSE' and TRADE_DAYS<='{0}'  order by TRADE_DAYS
    """.format(today)
    return pd.read_sql(query, calendar_engine)


# Shift a date forward or backward by a number of trading days
def step_trade_date(date, step, astype='%Y%m%d'):
    """Move ``date`` by ``step`` trading days and return it as a string.

    Args:
        date: a date string (parsed with ``dateutil``) or an object that
            provides ``strftime``.
        step: how many trading days to move. A positive value selects the
            ``step``-th trading day strictly after ``date``; a negative value
            selects the ``abs(step)``-th trading day strictly before it; ``0``
            returns ``date`` itself, reformatted, without consulting the
            calendar.
        astype: ``strftime`` format of the returned string.

    Returns:
        str: the resulting date formatted with ``astype``.
    """
    if isinstance(date, str):
        iso_date = parse(date).strftime('%Y-%m-%d')
    else:
        iso_date = date.strftime('%Y-%m-%d')

    if step == 0:
        return parse(iso_date).strftime(astype)

    calendar = get_tradeday()
    # Bring the anchor into the same form the calendar column is compared in.
    anchor = dt_handle(iso_date)

    if step > 0:
        later = calendar.loc[calendar['calendarDate'] > anchor]
        picked = later['calendarDate'].iloc[step - 1]
        return parse(picked).strftime(astype)
    if step < 0:
        earlier = calendar.loc[calendar['calendarDate'] < anchor]
        # A negative position counts back from the end of the earlier days.
        picked = earlier['calendarDate'].iloc[step]
        return parse(picked).strftime(astype)


# Snap a date to the nearest trading day
def adjust_to_trade_date(date, adjust='last', astype='%Y%m%d'):
    """Snap ``date`` onto a trading day.

    A date that is already in the calendar is returned as is (reformatted);
    otherwise the closest trading day in the requested direction is used.

    Args:
        date: a date string (parsed with ``dateutil``) or an object that
            provides ``strftime``.
        adjust: ``'last'`` picks the latest trading day on or before ``date``;
            ``'next'`` picks the earliest trading day on or after ``date``.
        astype: ``strftime`` format of the returned string.

    Returns:
        str: the selected trading day formatted with ``astype``.

    Raises:
        ValueError: if ``adjust`` is neither ``'last'`` nor ``'next'``.
        IndexError: from the positional lookup when the loaded calendar has
            no trading day on the requested side of ``date``.
    """
    # Reject an unknown mode up front instead of silently returning None
    # after a database round trip.
    if adjust not in ('last', 'next'):
        raise ValueError(f"adjust:{adjust} must be 'last' or 'next'!")

    date = parse(date).strftime('%Y-%m-%d') if isinstance(date, str) else date.strftime('%Y-%m-%d')
    trade_days = get_tradeday()['calendarDate']
    # Bring the anchor into the same form the calendar column is compared in.
    date = dt_handle(date)

    if adjust == 'last':
        return parse(trade_days[trade_days <= date].iloc[-1]).strftime(astype)
    return parse(trade_days[trade_days >= date].iloc[0]).strftime(astype)


# Return every trading day between a start date and an end date
def get_time_list(start_date, end_date, astype='list'):
    """Return all trading days between ``start_date`` and ``end_date``, inclusive.

    Args:
        start_date (str or datetime): range start; common date formats are
            accepted.
        end_date (str or datetime): range end; common date formats are
            accepted.
        astype (str): ``'list'`` returns a plain list of the calendar values;
            ``'pd'`` returns a DataFrame with a single ``trade_dt`` column.

    Returns:
        list or pandas.DataFrame: the trading days in range. Values are
        returned exactly as stored in the calendar (presumably 'YYYYMMDD'
        strings, since they are compared against that form - confirm against
        the table).

    Raises:
        ValueError: if ``astype`` is neither ``'pd'`` nor ``'list'``.

    Examples:
        >> time_list = get_time_list('20150101','20160101')
    """
    # Fail fast on a bad option before any parsing or database access.
    if astype not in ('pd', 'list'):
        raise ValueError(f"astype:{astype} must be 'pd' or 'list'!")

    start_date = parse(start_date).strftime('%Y-%m-%d') if isinstance(start_date, str) else start_date.strftime('%Y-%m-%d')
    end_date = parse(end_date).strftime('%Y-%m-%d') if isinstance(end_date, str) else end_date.strftime('%Y-%m-%d')
    # Bring both bounds into the same form the calendar column is compared in.
    start_date = dt_handle(start_date)
    end_date = dt_handle(end_date)

    df = get_tradeday()
    in_range = (df['calendarDate'] >= start_date) & (df['calendarDate'] <= end_date)
    # Non-inplace rename: renaming a filtered selection in place can trigger
    # pandas' SettingWithCopyWarning.
    selected = df.loc[in_range].rename(columns={'calendarDate': 'trade_dt'})

    if astype == 'pd':
        return selected
    return selected['trade_dt'].tolist()


# List every weekday that is not a trading day
def get_holidays(astype='%Y%m%d'):
    """Return the business days that are missing from the trading calendar.

    Business days are generated with ``pd.bdate_range`` from '1990-12-19'
    (presumably the first day covered by the calendar - confirm) up to the
    last ``calendarDate`` loaded, and every one not present in the calendar
    is reported.

    Args:
        astype: ``strftime`` format applied to each returned date.

    Returns:
        list of str: the non-trading business days in ascending order.
    """
    calendar = get_tradeday()
    trading_days = set(calendar['calendarDate'])
    last_day = calendar['calendarDate'].iloc[-1]
    # Formatted as 'YYYYMMDD' so the values can be set-compared with the
    # calendar column.
    business_days = pd.bdate_range('1990-12-19', last_day).strftime('%Y%m%d')
    missing = sorted(set(business_days) - trading_days)
    return [parse(day).strftime(astype) for day in missing]


# if __name__ == "__main__":
#     import datetime as dt
#     print(step_trade_date('20180607', 1))
#     # Note: 20171231 is not a trading day, so stepping forward by 1 lands on 20180102
#     print(step_trade_date('20171231', 1))
#     print(step_trade_date('20171231', 100))
#     print(step_trade_date('20171231', -1000))







